DIRICHLET_MULTINOM

Overview

The DIRICHLET_MULTINOM function computes statistical properties of the Dirichlet-multinomial distribution, a compound probability distribution that arises when category probabilities are uncertain. Also known as the Dirichlet compound multinomial (DCM) or multivariate Pólya distribution, it models scenarios where observations follow a multinomial distribution with probabilities drawn from a Dirichlet distribution.

This distribution is constructed by first drawing a probability vector \mathbf{p} from a Dirichlet distribution with concentration parameters \boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K), then drawing counts from a multinomial distribution with n trials and probability vector \mathbf{p}. The probability mass function is:

P(\mathbf{x} \mid n, \boldsymbol{\alpha}) = \frac{\Gamma(\alpha_0) \Gamma(n+1)}{\Gamma(n + \alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{\Gamma(\alpha_k) \Gamma(x_k + 1)}

where \alpha_0 = \sum_{k=1}^{K} \alpha_k is the sum of concentration parameters, and \mathbf{x} = (x_1, \ldots, x_K) represents counts in each of K categories with \sum x_k = n.
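
As a rough illustration of the formula above, the PMF can be evaluated in log space with the standard library's math.lgamma, which avoids overflow in the gamma functions. The helper name dm_logpmf is hypothetical and is not part of the DIRICHLET_MULTINOM implementation; this is a minimal sketch that assumes valid inputs (positive alpha, non-negative integer counts).

```python
import math

def dm_logpmf(x, alpha):
    """Log-PMF of the Dirichlet-multinomial, term by term from the formula.

    Hypothetical helper for illustration only; n is taken as sum(x).
    """
    n = sum(x)
    alpha0 = sum(alpha)
    # log of Gamma(alpha0) * Gamma(n + 1) / Gamma(n + alpha0)
    log_p = math.lgamma(alpha0) + math.lgamma(n + 1) - math.lgamma(n + alpha0)
    # log of the product over the K categories
    for x_k, a_k in zip(x, alpha):
        log_p += math.lgamma(x_k + a_k) - math.lgamma(a_k) - math.lgamma(x_k + 1)
    return log_p

# With alpha = (1, 1, 1) every way of splitting n = 10 into 3 counts is
# equally likely, so the PMF is 1 / C(12, 2) = 1/66.
pmf = math.exp(dm_logpmf([2, 3, 5], [1.0, 1.0, 1.0]))
print(round(pmf, 4))  # 0.0152
```

Working with log-gamma terms and exponentiating once at the end keeps the calculation stable for large n, where the individual gamma values would overflow a float.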

The expected value for category i is E(X_i) = n \alpha_i / \alpha_0, and the variance is:

\text{Var}(X_i) = n \frac{\alpha_i}{\alpha_0} \left(1 - \frac{\alpha_i}{\alpha_0}\right) \frac{n + \alpha_0}{1 + \alpha_0}
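
The two moment formulas can be checked numerically with a few lines of plain Python. The function name dm_mean_var is a hypothetical helper, and the sketch assumes positive alpha values and a positive integer n.

```python
def dm_mean_var(alpha, n):
    """Per-category mean and variance from the closed-form expressions."""
    alpha0 = sum(alpha)
    inflation = (n + alpha0) / (1 + alpha0)  # factor multiplying the multinomial variance
    means, variances = [], []
    for a in alpha:
        p = a / alpha0  # expected probability of this category
        means.append(n * p)
        variances.append(n * p * (1 - p) * inflation)
    return means, variances

means, variances = dm_mean_var([2.0, 3.0, 5.0], 10)
print([round(m, 4) for m in means])      # [2.0, 3.0, 5.0]
print([round(v, 4) for v in variances])  # [2.9091, 3.8182, 4.5455]
```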

The distribution exhibits overdispersion relative to the multinomial—the variance is inflated by a factor of (n + \alpha_0)/(1 + \alpha_0). This makes it suitable for modeling count data with extra variability, such as word frequencies in documents or allele counts in population genetics. The concentration parameter \alpha_0 controls the degree of overdispersion: smaller values produce greater variability, while larger values make the distribution approach a standard multinomial.
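
The overdispersion can also be seen by simulating the two-stage construction: draw a probability vector from the Dirichlet, then draw multinomial counts with it. The sketch below uses only the standard library (normalized gamma draws for the Dirichlet step). The function name sample_dm, the seed, and the sample size are illustrative assumptions, not part of the implementation described here.

```python
import random

def sample_dm(alpha, n, rng):
    """Draw one count vector: p ~ Dirichlet(alpha), then x ~ Multinomial(n, p)."""
    # Normalized Gamma(alpha_k, 1) draws follow a Dirichlet(alpha) distribution.
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    p = [v / total for v in g]
    # n independent categorical draws with probabilities p, tallied per category.
    counts = [0] * len(alpha)
    for k in rng.choices(range(len(alpha)), weights=p, k=n):
        counts[k] += 1
    return counts

rng = random.Random(0)
alpha, n, draws = [2.0, 3.0, 5.0], 10, 20000
first = [sample_dm(alpha, n, rng)[0] for _ in range(draws)]

# Empirical variance of the first category's count
mean = sum(first) / draws
emp_var = sum((v - mean) ** 2 for v in first) / (draws - 1)

# Theoretical variances for comparison
p0 = alpha[0] / sum(alpha)
multinomial_var = n * p0 * (1 - p0)                             # about 1.6
dm_var = multinomial_var * (n + sum(alpha)) / (1 + sum(alpha))  # about 2.909
print(emp_var, multinomial_var, dm_var)
```

The empirical variance should land near the Dirichlet-multinomial value rather than the smaller multinomial one, although the exact figure depends on the seed and the number of draws.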

This implementation uses SciPy's scipy.stats.dirichlet_multinomial distribution and supports computing the PMF, log-PMF, mean, variance, and covariance matrix. For additional theoretical background, see the Wikipedia article on the Dirichlet-multinomial distribution.

This example function is provided as-is without any representation of accuracy.

Excel Usage

=DIRICHLET_MULTINOM(x, alpha, n, dm_method)
  • x (list[list], optional, default: null): 2D list of integer counts for each category. Required for pmf and logpmf methods.
  • alpha (list[list], optional, default: null): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Required for all methods; the function returns an error if it is omitted.
  • n (list[list], optional, default: null): 2D list containing the number of trials for each distribution. Each row contains one integer. Required for all methods except cov, where it defaults to 1 when omitted.
  • dm_method (str, optional, default: "pmf"): Computation method to use. Valid options: pmf, logpmf, mean, var, cov.

Returns (list[list]): 2D list of results, or error message string.

Examples

Example 1: Basic PMF calculation with uniform concentration

Inputs:

x | alpha | n | dm_method
2 3 5 | 1 1 1 | 10 | pmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "pmf")

Expected output:

Result
0.0152

Example 2: Log-PMF calculation for same distribution

Inputs:

x | alpha | n | dm_method
2 3 5 | 1 1 1 | 10 | logpmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "logpmf")

Expected output:

Result
-4.1897

Example 3: Expected mean counts for weighted concentration

Inputs:

alpha | n | dm_method
2 3 5 | 10 | mean

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "mean")

Expected output:

Result
2 3 5

Example 4: Variance for weighted concentration

Inputs:

alpha | n | dm_method
2 3 5 | 10 | var

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "var")

Expected output:

Result
2.9091 3.8182 4.5455

Example 5: Covariance matrix for three categories

Inputs:

alpha | dm_method
2 3 5 | cov

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, , "cov")

Expected output:

Result
0.16 -0.06 -0.1
-0.06 0.21 -0.15
-0.1 -0.15 0.25

Python Code

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial

def dirichlet_multinom(x=None, alpha=None, n=None, dm_method='pmf'):
    """
    Computes the probability mass function, log probability mass function, mean, variance, or covariance of the Dirichlet multinomial distribution.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet_multinomial.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        x (list[list], optional): 2D list of integer counts for each category. Required for pmf and logpmf methods. Default is None.
        alpha (list[list], optional): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Required for all methods. Default is None.
        n (list[list], optional): 2D list containing the number of trials for each distribution. Each row contains one integer. Required for all methods except cov, where it defaults to 1 when omitted. Default is None.
        dm_method (str, optional): Computation method to use. Valid options: 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        list[list]: 2D list of results, or error message string.
    """
    def to2d(val):
      if val is None:
        return None
      return [[val]] if not isinstance(val, list) else val

    def to_float_list(arr):
      if hasattr(arr, 'tolist'):
        arr = arr.tolist()
      if isinstance(arr, (float, int)):
        return [float(arr)]
      return [float(v) for v in arr]

    valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
    if dm_method not in valid_methods:
      return f"Error: Invalid method '{dm_method}'. Must be one of {sorted(valid_methods)}."

    if alpha is None:
      return "Error: Invalid input: alpha is required."
    alpha = to2d(alpha)
    if not isinstance(alpha, list) or not all(isinstance(row, list) and len(row) > 0 for row in alpha):
      return "Error: alpha must be a 2D list of positive floats."
    if len(alpha) < 1:
      return "Error: alpha must have at least one row."

    try:
      alpha = [[float(v) for v in row] for row in alpha]
    except (TypeError, ValueError):
      return "Error: alpha must be a 2D list of positive floats."
    if any(any(v <= 0 for v in row) for row in alpha):
      return "Error: alpha must be a 2D list of positive floats."

    # n is required for pmf/logpmf/mean/var; for cov it is optional and defaults to 1 later
    if n is None and dm_method != 'cov':
      return "Error: Invalid input: n is required."
    if n is not None:
      # Same validation for every method, with consistently prefixed error messages
      n = to2d(n)
      if not isinstance(n, list) or len(n) != len(alpha):
        return "Error: n must be a 2D list with the same number of rows as alpha."
      for n_row in n:
        if not isinstance(n_row, list) or len(n_row) != 1:
          return "Error: Each row of n must contain exactly one integer."
      try:
        n = [[int(val[0])] for val in n]
      except (TypeError, ValueError):
        return "Error: n must contain integers."
      if any(val[0] < 0 for val in n):
        return "Error: n must contain non-negative integers."

    if dm_method in {'pmf', 'logpmf'}:
      if x is None:
        return "Error: Invalid input: x is required for pmf/logpmf."
      x = to2d(x)
      if not isinstance(x, list) or len(x) != len(alpha):
        return "Error: x must be a 2D list with the same number of rows as alpha."
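      # Compare each x row with its own alpha row, not only with the first one
      for row, alpha_row in zip(x, alpha):
        if not isinstance(row, list) or len(row) != len(alpha_row):
          return "Error: Each row of x must have the same length as the corresponding alpha row."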
        try:
          if any(int(val) < 0 for val in row):
            return "Error: x must contain non-negative integers."
        except (TypeError, ValueError):
          return "Error: x must contain integers."

    results = []
    for i, alpha_row in enumerate(alpha):
      try:
        if dm_method == 'cov':
          n_val = 1 if n is None else n[i][0]
        else:
          n_val = n[i][0]

        if dm_method in {'pmf', 'logpmf'}:
          row_sum = sum(int(v) for v in x[i])
          if row_sum != n_val:
            return "Error: Invalid input: each row of x must sum to n."

        dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)

        if dm_method == 'pmf':
          res = dist.pmf(x[i])
        elif dm_method == 'logpmf':
          res = dist.logpmf(x[i])
        elif dm_method == 'mean':
          res = dist.mean()
        elif dm_method == 'var':
          res = dist.var()
        elif dm_method == 'cov':
          res = dist.cov()

        if dm_method == 'cov':
          cov_matrix = res.tolist() if hasattr(res, 'tolist') else res
          for row in cov_matrix:
            results.append([float(val) for val in row])
        else:
          results.append(to_float_list(res))
      except Exception as e:
        return f"Error: computing {dm_method}: {e}"

    return results
